from collections import Counter
from math import log
import operator


# 计算信息熵
# Compute the Shannon entropy of a dataset.
def cal_entropy(data):
    """Return the Shannon entropy (base 2) of the class labels in `data`.

    Each sample in `data` is a sequence whose LAST element is the class
    label. Returns 0.0 for an empty dataset (matches the original, whose
    loop body never ran).
    """
    num = len(data)  # number of samples
    # Counter replaces the hand-rolled "if key not in dict" counting loop.
    counts = Counter(sample[-1] for sample in data)
    entropy = 0.0
    for cnt in counts.values():
        prob = cnt / num  # relative frequency of this class
        entropy -= prob * log(prob, 2)
    return entropy


# 划分数据集
# Split the dataset on one attribute value.
def split_data(data, axis, v):
    """Return the samples whose value at column `axis` equals `v`,
    with that column removed from each returned sample.
    """
    return [sample[:axis] + sample[axis + 1:]
            for sample in data
            if sample[axis] == v]


# 选择最好的数据集划分方式
# Choose the best attribute to split the dataset on.
def best_split(data):
    """Return the index of the attribute with the largest information
    gain, or -1 if no attribute yields a positive gain.
    """
    n_features = len(data[0]) - 1  # last column is the class label
    base_entropy = cal_entropy(data)
    total = float(len(data))
    best_gain, best_feature = 0.0, -1
    for idx in range(n_features):  # information gain of each attribute
        values = {sample[idx] for sample in data}  # distinct values of this attribute
        split_entropy = 0.0
        for value in values:  # weighted entropy over each branch
            subset = split_data(data, idx, value)
            split_entropy += (len(subset) / total) * cal_entropy(subset)
        gain = base_entropy - split_entropy
        if gain > best_gain:  # keep the attribute with maximal gain
            best_gain, best_feature = gain, idx
    return best_feature


# 通过排序返回出现次数最多的类别
# Majority vote: the most frequent class label.
def best_cnt(class_list):
    """Return the most common label in `class_list`.

    Ties are broken in favor of the label encountered first, exactly as
    the original stable `sorted(..., reverse=True)` did —
    `Counter.most_common` documents the same first-encountered ordering
    for equal counts.
    """
    return Counter(class_list).most_common(1)[0][0]


# 递归构建决策树
# Recursively build the decision tree (ID3).
def create_tree(data, labels):
    """Build an ID3 decision tree as nested dicts.

    data   -- list of samples; last element of each is the class label.
    labels -- attribute names aligned with the feature columns.

    Returns either a class label (leaf) or
    {attribute_label: {attribute_value: subtree, ...}}.

    Bug fix vs. original: the original did `del labels[best]`, mutating
    the CALLER's label list as a side effect; we now work on a copy so
    `labels` is left untouched. The returned tree is identical.
    """
    class_list = [sample[-1] for sample in data]
    # All samples share one class: return it as a leaf.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # No attributes left to split on: majority-vote leaf.
    if len(data[0]) == 1:
        return best_cnt(class_list)
    best = best_split(data)  # index of the best attribute
    # NOTE(review): if best_split found no positive gain it returns -1,
    # so labels[-1] would be used — same as the original's behavior.
    label = labels[best]
    tree = {label: {}}
    # Copy without the chosen attribute; a chosen attribute is not reused.
    remaining = labels[:best] + labels[best + 1:]
    for value in set(sample[best] for sample in data):
        # One branch per observed value of the chosen attribute.
        tree[label][value] = create_tree(split_data(data, best, value), remaining)
    return tree
